{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "f705f4be70e9"
      },
      "outputs": [],
      "source": [
        "# Copyright 2025 Google LLC\n",
        "#\n",
        "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "#     https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eebd5c14c254"
      },
      "source": [
        "# Task Planner Agent"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "a4d2041c00a9"
      },
      "source": [
        "<table align=\"left\">\n",
        "  <td style=\"text-align: center\">\n",
        "    <a href=\"https://colab.research.google.com/github/GoogleCloudPlatform/generative-ai/blob/main/gemini/agents/genai-experience-concierge/agent-design-patterns/task-planner.ipynb\">\n",
        "      <img width=\"32px\" src=\"https://www.gstatic.com/pantheon/images/bigquery/welcome_page/colab-logo.svg\" alt=\"Google Colaboratory logo\"><br> Open in Colab\n",
        "    </a>\n",
        "  </td>\n",
        "  <td style=\"text-align: center\">\n",
        "    <a href=\"https://console.cloud.google.com/vertex-ai/colab/import/https:%2F%2Fraw.githubusercontent.com%2FGoogleCloudPlatform%2Fgenerative-ai%2Fmain%2Fgemini%2Fagents%2Fgenai-experience-concierge%2Fagent-design-patterns%2Ftask-planner.ipynb\">\n",
        "      <img width=\"32px\" src=\"https://lh3.googleusercontent.com/JmcxdQi-qOpctIvWKgPtrzZdJJK-J3sWE1RsfjZNwshCFgE_9fULcNpuXYTilIR2hjwN\" alt=\"Google Cloud Colab Enterprise logo\"><br> Open in Colab Enterprise\n",
        "    </a>\n",
        "  </td>\n",
        "  <td style=\"text-align: center\">\n",
        "    <a href=\"https://console.cloud.google.com/vertex-ai/workbench/deploy-notebook?download_url=https://raw.githubusercontent.com/GoogleCloudPlatform/generative-ai/main/gemini/agents/genai-experience-concierge/agent-design-patterns/task-planner.ipynb\">\n",
        "      <img src=\"https://www.gstatic.com/images/branding/gcpiconscolors/vertexai/v1/32px.svg\" alt=\"Vertex AI logo\"><br> Open in Vertex AI Workbench\n",
        "    </a>\n",
        "  </td>\n",
        "  <td style=\"text-align: center\">\n",
        "    <a href=\"https://github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/agents/genai-experience-concierge/agent-design-patterns/task-planner.ipynb\">\n",
        "      <img width=\"32px\" src=\"https://raw.githubusercontent.com/primer/octicons/refs/heads/main/icons/mark-github-24.svg\" alt=\"GitHub logo\"><br> View on GitHub\n",
        "    </a>\n",
        "  </td>\n",
        "</table>\n",
        "\n",
        "<div style=\"clear: both;\"></div>\n",
        "\n",
        "<b>Share to:</b>\n",
        "\n",
        "<a href=\"https://www.linkedin.com/sharing/share-offsite/?url=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/agents/genai-experience-concierge/agent-design-patterns/task-planner.ipynb\" target=\"_blank\">\n",
        "  <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/8/81/LinkedIn_icon.svg\" alt=\"LinkedIn logo\">\n",
        "</a>\n",
        "\n",
        "<a href=\"https://bsky.app/intent/compose?text=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/agents/genai-experience-concierge/agent-design-patterns/task-planner.ipynb\" target=\"_blank\">\n",
        "  <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/7/7a/Bluesky_Logo.svg\" alt=\"Bluesky logo\">\n",
        "</a>\n",
        "\n",
        "<a href=\"https://twitter.com/intent/tweet?url=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/agents/genai-experience-concierge/agent-design-patterns/task-planner.ipynb\" target=\"_blank\">\n",
        "  <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/5/5a/X_icon_2.svg\" alt=\"X logo\">\n",
        "</a>\n",
        "\n",
        "<a href=\"https://reddit.com/submit?url=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/agents/genai-experience-concierge/agent-design-patterns/task-planner.ipynb\" target=\"_blank\">\n",
        "  <img width=\"20px\" src=\"https://redditinc.com/hubfs/Reddit%20Inc/Brand/Reddit_Logo.png\" alt=\"Reddit logo\">\n",
        "</a>\n",
        "\n",
        "<a href=\"https://www.facebook.com/sharer/sharer.php?u=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/agents/genai-experience-concierge/agent-design-patterns/task-planner.ipynb\" target=\"_blank\">\n",
        "  <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/5/51/Facebook_f_logo_%282019%29.svg\" alt=\"Facebook logo\">\n",
        "</a>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dbfe0a3c85ab"
      },
      "source": [
        "| | |\n",
        "|-|-|\n",
        "|Author(s) | [Pablo Gaeta](https://github.com/pablofgaeta) |"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "e7a8a0951f45"
      },
      "source": [
        "## Overview\n",
        "\n",
        "### Introduction\n",
        "\n",
        "This notebook demonstrates an implementation of a task planner agent (similar to [\"Deep Research\"](https://gemini.google/overview/deep-research)) This is a multi-agent architecture useful for tasks requiring more complex reasoning, planning, and multi-tool use.\n",
        "\n",
        "This architecture is often much slower than single-agent designs because a single turn can consist of a large number of LLM calls and tool usage. This demo is particularly slow because the \"Executor\" agent only supports linear plans and executes each task in parallel. There is research on alternative approaches such as [LLM Compiler](https://arxiv.org/abs/2312.04511) that attempt to improve this design by constructing DAGs to enable parallel task execution.\n",
        "\n",
        "The \"Executor\" agent in this demo is a Gemini model equipped with the Google Search Grounding Tool ([documentation](https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/ground-with-google-search)) to enable live web search while executing tasks.\n",
        "\n",
        "### Key Components\n",
        "\n",
        "The task planner agent is built around several key components:\n",
        "\n",
        "* **Language Model:** Gemini is used for natural language understanding, function calling, and response generation for multiple agents.\n",
        "* **State Management:** LangGraph manages the conversation flow and maintains the session state, including conversation history and generated/executed plans.\n",
        "* **Planner Node:** Generates a plan or a direct response to the user input. The plan consists of a sequence of tasks to be executed.\n",
        "* **Executor Node:** Executes the tasks defined in the plan, typically using tools like Google Search to gather information.\n",
        "* **Reflector Node:** Analyzes the results of the executed plan and the user's input to determine the next action - either generating a new plan for further execution or formulating a final response to the user.\n",
        "\n",
        "### Workflow\n",
        "\n",
        "The agent operates through the following workflow:\n",
        "\n",
        "1. The **Planner** receives user input and either (1) responds directly to simple queries (e.g. \"Hi\") or (2) generates a research plan, including list of tasks to execute.\n",
        "1. The **Executor** receives the plan and uses its tools to perform each task and update the plan with the executed task result.\n",
        "1. The **Reflector** reviews the executed plan and either (1) generates a final response to the user or (2) generates a new plan and jumps back to step 2."
      ]
    },
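    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7c2e9b41d0a3"
      },
      "source": [
        "The following cell is a minimal, framework-free sketch of the workflow above, not this notebook's implementation. The `sketch_*` functions and `SketchPlan` are hypothetical stubs standing in for the Gemini-backed Planner, Executor, and Reflector, and `max_rounds` is an assumed safety cap on how many plan, execute, and reflect rounds a single turn may take. The notebook itself wires these steps together as LangGraph nodes."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1f8a6d3e5b27"
      },
      "outputs": [],
      "source": [
        "from dataclasses import dataclass, field\n",
        "\n",
        "\n",
        "@dataclass\n",
        "class SketchPlan:\n",
        "    \"\"\"Hypothetical stand-in for the notebook's Plan: a goal, ordered tasks, and their results.\"\"\"\n",
        "\n",
        "    goal: str\n",
        "    tasks: list[str]\n",
        "    results: list[str] = field(default_factory=list)\n",
        "\n",
        "\n",
        "def sketch_plan(user_input: str) -> SketchPlan | str:\n",
        "    # Stub Planner: answer greetings directly, otherwise return a research plan.\n",
        "    if user_input.lower() in {\"hi\", \"hello\"}:\n",
        "        return \"Hello! What would you like me to research?\"\n",
        "    return SketchPlan(goal=user_input, tasks=[f\"Search: {user_input}\", \"Summarize findings\"])\n",
        "\n",
        "\n",
        "def sketch_execute(plan: SketchPlan) -> SketchPlan:\n",
        "    # Stub Executor: runs each task in order and records a placeholder result.\n",
        "    plan.results = [f\"result of '{task}'\" for task in plan.tasks]\n",
        "    return plan\n",
        "\n",
        "\n",
        "def sketch_reflect(plan: SketchPlan, round_idx: int) -> SketchPlan | str:\n",
        "    # Stub Reflector: requests one follow-up plan, then answers.\n",
        "    if round_idx == 0:\n",
        "        return SketchPlan(goal=plan.goal, tasks=[\"Verify key facts\"])\n",
        "    return f\"Final answer based on {len(plan.results)} result(s).\"\n",
        "\n",
        "\n",
        "def sketch_run_turn(user_input: str, max_rounds: int = 3) -> str:\n",
        "    action = sketch_plan(user_input)\n",
        "    for round_idx in range(max_rounds):\n",
        "        if isinstance(action, str):  # A direct response ends the turn.\n",
        "            return action\n",
        "        executed_plan = sketch_execute(action)\n",
        "        action = sketch_reflect(executed_plan, round_idx)\n",
        "    if isinstance(action, str):\n",
        "        return action\n",
        "    return \"Stopped: reached max_rounds without a final answer.\"\n",
        "\n",
        "\n",
        "print(sketch_run_turn(\"hi\"))\n",
        "print(sketch_run_turn(\"Compare the populations of Tokyo and Delhi\"))"
      ]
    },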
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "93b9c58f24d9"
      },
      "source": [
        "## Get Started"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "d5027929de8f"
      },
      "source": [
        "### Install dependencies"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "id": "aa7abf5449af"
      },
      "outputs": [],
      "source": [
        "%pip install -q google-genai langgraph langgraph-checkpoint pydantic tenacity"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "f42d12d15616"
      },
      "source": [
        "### Restart runtime\n",
        "\n",
        "To use the newly installed packages in this Jupyter runtime, you must restart the runtime. You can do this by running the cell below, which restarts the current kernel.\n",
        "\n",
        "The restart might take a minute or longer. After it's restarted, continue to the next step."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "06fd78d27773"
      },
      "outputs": [],
      "source": [
        "# import IPython\n",
        "\n",
        "# app = IPython.Application.instance()\n",
        "# app.kernel.do_shutdown(True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "e114f5653870"
      },
      "source": [
        "### Authenticate your notebook environment (Colab only)\n",
        "\n",
        "If you're running this notebook on Google Colab, run the cell below to authenticate your environment."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "911453311a5d"
      },
      "outputs": [],
      "source": [
        "import sys\n",
        "\n",
        "if \"google.colab\" in sys.modules:\n",
        "    from google.colab import auth\n",
        "\n",
        "    auth.authenticate_user()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0724a3d2c4f9"
      },
      "source": [
        "## Notebook parameters"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "id": "47ad1afc6f31"
      },
      "outputs": [],
      "source": [
        "# Use the environment variable if the user doesn't provide Project ID.\n",
        "import os\n",
        "\n",
        "PROJECT_ID = \"[your-project-id]\"  # @param {type: \"string\", placeholder: \"[your-project-id]\", isTemplate: true}\n",
        "if not PROJECT_ID or PROJECT_ID == \"[your-project-id]\":\n",
        "    PROJECT_ID = str(os.environ.get(\"GOOGLE_CLOUD_PROJECT\"))\n",
        "\n",
        "REGION = \"us-central1\"  # @param {type:\"string\"}\n",
        "PLANNER_MODEL_NAME = \"gemini-2.0-flash-001\"  # @param {type:\"string\"}\n",
        "REFLECTOR_MODEL_NAME = \"gemini-2.0-flash-001\"  # @param {type:\"string\"}\n",
        "EXECUTOR_MODEL_NAME = \"gemini-2.0-flash-001\"  # @param {type:\"string\"}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "b204dd9a3a93"
      },
      "source": [
        "## Define the Task Planner Agent"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6a04ecba1630"
      },
      "source": [
        "### Import dependencies"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "id": "db0053639f69"
      },
      "outputs": [],
      "source": [
        "from collections.abc import AsyncGenerator\n",
        "import datetime\n",
        "from typing import Literal, TypedDict\n",
        "import uuid\n",
        "\n",
        "from IPython import display as ipd\n",
        "from google import genai\n",
        "from google.genai import errors as genai_errors\n",
        "from google.genai import types as genai_types\n",
        "from langchain_core.runnables import config as lc_config\n",
        "from langgraph import graph\n",
        "from langgraph import types as lg_types\n",
        "from langgraph.checkpoint import memory\n",
        "from langgraph.config import get_stream_writer\n",
        "import pydantic\n",
        "import requests\n",
        "from tenacity import retry, retry_if_exception, stop_after_attempt, wait_exponential"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "f6563a070bad"
      },
      "source": [
        "### Define schemas\n",
        "\n",
        "Defines all of the schemas, constants, and types required for building the agent."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "id": "4f51460145a6"
      },
      "outputs": [],
      "source": [
        "# Agent config settings\n",
        "\n",
        "\n",
        "class AgentConfig(pydantic.BaseModel):\n",
        "    \"\"\"Configuration settings for the agent, including project, region, and model details.\"\"\"\n",
        "\n",
        "    project: str\n",
        "    \"\"\"The Google Cloud project ID.\"\"\"\n",
        "    region: str\n",
        "    \"\"\"The Google Cloud region where the agent is deployed.\"\"\"\n",
        "    planner_model_name: str\n",
        "    \"\"\"The name of the Gemini model to use for planning.\"\"\"\n",
        "    executor_model_name: str\n",
        "    \"\"\"The name of the Gemini model to use for executing tasks.\"\"\"\n",
        "    reflector_model_name: str\n",
        "    \"\"\"The name of the Gemini model to use for reflecting on the plan and results.\"\"\"\n",
        "\n",
        "\n",
        "# Node names and literal types\n",
        "\n",
        "REFLECTOR_NODE_NAME = \"REFLECTOR\"\n",
        "\"\"\"The name of the reflector node in the LangGraph.\"\"\"\n",
        "ReflectorNodeTargetLiteral = Literal[\"REFLECTOR\"]\n",
        "\"\"\"Literal type for the reflector node target.\"\"\"\n",
        "\n",
        "EXECUTOR_NODE_NAME = \"EXECUTOR\"\n",
        "\"\"\"The name of the executor node in the LangGraph.\"\"\"\n",
        "ExecutorNodeTargetLiteral = Literal[\"EXECUTOR\"]\n",
        "\"\"\"Literal type for the executor node target.\"\"\"\n",
        "\n",
        "PLANNER_NODE_NAME = \"PLANNER\"\n",
        "\"\"\"The name of the planner node in the LangGraph.\"\"\"\n",
        "PlannerNodeTargetLiteral = Literal[\"PLANNER\"]\n",
        "\"\"\"Literal type for the planner node target.\"\"\"\n",
        "\n",
        "POST_PROCESS_NODE_NAME = \"POST_PROCESS\"\n",
        "\"\"\"The name of the post-processing node in the LangGraph.\"\"\"\n",
        "PostProcessNodeTargetLiteral = Literal[\"POST_PROCESS\"]\n",
        "\"\"\"Literal type for the post-processing node target.\"\"\"\n",
        "\n",
        "EndNodeTargetLiteral = Literal[\"__end__\"]\n",
        "\"\"\"Literal type for the end node target.\"\"\"\n",
        "\n",
        "# langgraph models\n",
        "\n",
        "\n",
        "class Task(pydantic.BaseModel):\n",
        "    \"\"\"An individual task with a goal and result.\"\"\"\n",
        "\n",
        "    goal: str = pydantic.Field(\n",
        "        description=\"The description and goal of this step in the plan.\",\n",
        "    )\n",
        "    \"\"\"The description and goal of this step in the plan.\"\"\"\n",
        "\n",
        "    result: str | None = pydantic.Field(\n",
        "        default=None,\n",
        "        description=\"The result of this step determined by the plan executor. Always set this field to None\",\n",
        "    )\n",
        "    \"\"\"The result of this step determined by the plan executor. Always set this field to None.\"\"\"\n",
        "\n",
        "\n",
        "class Plan(pydantic.BaseModel):\n",
        "    \"\"\"A step-by-step sequential plan.\"\"\"\n",
        "\n",
        "    goal: str = pydantic.Field(description=\"High level goal for plan to help user.\")\n",
        "    \"\"\"High level goal for plan to help user.\"\"\"\n",
        "    tasks: list[Task] = pydantic.Field(\n",
        "        description=\"A list of individual tasks that will be executed in sequence before responding to the user. As the task gets more complex, you can add more steps.\",\n",
        "    )\n",
        "    \"\"\"A list of individual tasks that will be executed in sequence before responding to the user. As the task gets more complex, you can add more steps.\"\"\"\n",
        "\n",
        "\n",
        "class Response(pydantic.BaseModel):\n",
        "    \"\"\"Response to send to the user.\"\"\"\n",
        "\n",
        "    response: str\n",
        "    \"\"\"The response message to send to the user.\"\"\"\n",
        "\n",
        "\n",
        "class PlanOrRespond(pydantic.BaseModel):\n",
        "    \"\"\"Action to perform. Either respond to user or generate a research plan.\"\"\"\n",
        "\n",
        "    action: Response | Plan = pydantic.Field(\n",
        "        description=\"The next action can either be a direct response to the user or generate a new plan if you need to think more and use tools.\"\n",
        "    )\n",
        "    \"\"\"The next action can either be a direct response to the user or generate a new plan if you need to think more and use tools.\"\"\"\n",
        "\n",
        "\n",
        "# LangGraph models\n",
        "\n",
        "\n",
        "class Turn(TypedDict, total=False):\n",
        "    \"\"\"\n",
        "    Represents a single turn in a conversation.\n",
        "\n",
        "    Attributes:\n",
        "        id: Unique identifier for the turn.\n",
        "        created_at: Timestamp of when the turn was created.\n",
        "        user_input: The user's input in this turn.\n",
        "        response: The agent's response in this turn, if any.\n",
        "        plan: The agent's last generated plan for this turn, if any.\n",
        "        messages: A list of Gemini content messages associated with this turn.\n",
        "    \"\"\"\n",
        "\n",
        "    id: uuid.UUID\n",
        "    \"\"\"Unique identifier for the turn.\"\"\"\n",
        "\n",
        "    created_at: datetime.datetime\n",
        "    \"\"\"Timestamp of when the turn was created.\"\"\"\n",
        "\n",
        "    user_input: str\n",
        "    \"\"\"The user's input for this turn.\"\"\"\n",
        "\n",
        "    response: str\n",
        "    \"\"\"The agent's response for this turn, if any.\"\"\"\n",
        "\n",
        "    plan: Plan | None\n",
        "    \"\"\"The agent's last generated plan for this turn, if any.\"\"\"\n",
        "\n",
        "    messages: list[genai_types.Content]\n",
        "    \"\"\"List of Gemini Content objects representing the conversation messages in this turn.\"\"\"\n",
        "\n",
        "\n",
        "class GraphSession(TypedDict, total=False):\n",
        "    \"\"\"\n",
        "    Represents the complete state of a conversation session.\n",
        "\n",
        "    Attributes:\n",
        "        id: Unique identifier for the session.\n",
        "        created_at: Timestamp of when the session was created.\n",
        "        current_turn: The current turn in the session, if any.\n",
        "        turns: A list of all turns in the session.\n",
        "    \"\"\"\n",
        "\n",
        "    id: uuid.UUID\n",
        "    \"\"\"Unique identifier for the session.\"\"\"\n",
        "\n",
        "    created_at: datetime.datetime\n",
        "    \"\"\"Timestamp of when the session was created.\"\"\"\n",
        "\n",
        "    current_turn: Turn | None\n",
        "    \"\"\"The current conversation turn.\"\"\"\n",
        "\n",
        "    turns: list[Turn]\n",
        "    \"\"\"List of all conversation turns in the session.\"\"\""
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "39029fa4f4a1"
      },
      "source": [
        "### Utility Functions"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "id": "78f47e72fb6b"
      },
      "outputs": [],
      "source": [
        "def is_retryable_error(exception: BaseException) -> bool:\n",
        "    \"\"\"\n",
        "    Determines if a given exception is considered retryable.\n",
        "\n",
        "    This function checks if the provided exception is an API error with a retryable HTTP status code\n",
        "    (429, 502, 503, 504) or a connection error.\n",
        "\n",
        "    Args:\n",
        "        exception: The exception to evaluate.\n",
        "\n",
        "    Returns:\n",
        "        True if the exception is retryable, False otherwise.\n",
        "    \"\"\"\n",
        "\n",
        "    if isinstance(exception, genai_errors.APIError):\n",
        "        return exception.code in [429, 502, 503, 504]\n",
        "    if isinstance(exception, requests.exceptions.ConnectionError):\n",
        "        return True\n",
        "    return False"
      ]
    },
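    {
      "cell_type": "markdown",
      "metadata": {
        "id": "b3d95e0a7c14"
      },
      "source": [
        "The next cell is an offline sketch of how a predicate like `is_retryable_error` plugs into tenacity's `retry_if_exception`, mirroring the decorators on the agent functions below. `TransientError`, `is_transient`, `flaky_call`, and `broken_call` are hypothetical: the error stands in for a retryable API error such as HTTP 429, and `wait_fixed(0)` replaces the real backoff so the demo runs instantly. A retryable error is retried up to the attempt limit, while any other error is re-raised after the first attempt. With the agents' settings (`wait_exponential(min=1, max=10)` and `stop_after_attempt(3)`), tenacity should wait roughly 1 s and then 2 s between the three attempts."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4a6c1e8f2d90"
      },
      "outputs": [],
      "source": [
        "from tenacity import retry, retry_if_exception, stop_after_attempt, wait_fixed\n",
        "\n",
        "\n",
        "class TransientError(Exception):\n",
        "    \"\"\"Hypothetical stand-in for a retryable error, such as HTTP 429 or 503.\"\"\"\n",
        "\n",
        "\n",
        "def is_transient(exception: BaseException) -> bool:\n",
        "    # Same shape as is_retryable_error: return True only for errors worth retrying.\n",
        "    return isinstance(exception, TransientError)\n",
        "\n",
        "\n",
        "calls = {\"flaky\": 0, \"broken\": 0}\n",
        "\n",
        "\n",
        "@retry(\n",
        "    retry=retry_if_exception(is_transient),\n",
        "    wait=wait_fixed(0),  # No delay, so the demo runs instantly.\n",
        "    stop=stop_after_attempt(3),\n",
        "    reraise=True,\n",
        ")\n",
        "def flaky_call() -> str:\n",
        "    # Fails twice with a retryable error, then succeeds on the third attempt.\n",
        "    calls[\"flaky\"] += 1\n",
        "    if calls[\"flaky\"] < 3:\n",
        "        raise TransientError(\"service temporarily unavailable\")\n",
        "    return \"ok\"\n",
        "\n",
        "\n",
        "@retry(\n",
        "    retry=retry_if_exception(is_transient),\n",
        "    wait=wait_fixed(0),\n",
        "    stop=stop_after_attempt(3),\n",
        "    reraise=True,\n",
        ")\n",
        "def broken_call() -> None:\n",
        "    # Raises a non-retryable error, so tenacity re-raises it after the first attempt.\n",
        "    calls[\"broken\"] += 1\n",
        "    raise ValueError(\"bad request\")\n",
        "\n",
        "\n",
        "print(flaky_call(), calls[\"flaky\"])  # ok 3\n",
        "\n",
        "try:\n",
        "    broken_call()\n",
        "except ValueError as exc:\n",
        "    print(type(exc).__name__, calls[\"broken\"])  # ValueError 1"
      ]
    },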
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {
        "id": "4626e780f09c"
      },
      "outputs": [],
      "source": [
        "def stringify_task(task: Task, include_results: bool = True) -> str:\n",
        "    \"\"\"\n",
        "    Formats a task into a human-readable string.\n",
        "\n",
        "    This function takes a task and converts it into a formatted string,\n",
        "    including the task goal and optionally the task result.\n",
        "\n",
        "    Args:\n",
        "        task (Task): The task.\n",
        "        include_results (bool, optional): Whether to include the task result in the output. Defaults to True.\n",
        "\n",
        "    Returns:\n",
        "        str: The formatted task string.\n",
        "    \"\"\"\n",
        "    output = f\"**Goal**: {task.goal}\"\n",
        "\n",
        "    if include_results:\n",
        "        output += f\"\\n\\n**Result**: {task.result or 'incomplete'}\"\n",
        "\n",
        "    return output"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {
        "id": "9f401796d6b2"
      },
      "outputs": [],
      "source": [
        "def stringify_plan(plan: Plan, include_results: bool = True) -> str:\n",
        "    \"\"\"\n",
        "    Formats an execution plan into a human-readable string.\n",
        "\n",
        "    This function takes an execution plan and converts it into a formatted string,\n",
        "    including the goal and a list of tasks.\n",
        "\n",
        "    Args:\n",
        "        plan (Plan): The execution plan.\n",
        "        include_results (bool, optional): Whether to include task results in the output. Defaults to True.\n",
        "\n",
        "    Returns:\n",
        "        str: The formatted execution plan string.\n",
        "    \"\"\"\n",
        "    tasks_str = \"\\n\\n\".join(\n",
        "        f\"**Task #{idx + 1}**\\n\\n\"\n",
        "        + stringify_task(task, include_results=include_results)\n",
        "        for idx, task in enumerate(plan.tasks)\n",
        "    )\n",
        "\n",
        "    response = f\"**Plan**: {plan.goal}\\n\\n{tasks_str}\"\n",
        "\n",
        "    return response"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eeb64aa6c52f"
      },
      "source": [
        "### Core Agent Operations"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "a0dc6a0fac37"
      },
      "source": [
        "#### Plan Generator\n",
        "\n",
        "Generates a plan or a direct response based on the current turn and conversation history."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {
        "id": "57e92fda9c7c"
      },
      "outputs": [],
      "source": [
        "@retry(\n",
        "    retry=retry_if_exception(is_retryable_error),\n",
        "    wait=wait_exponential(min=1, max=10),\n",
        "    stop=stop_after_attempt(3),\n",
        "    reraise=True,\n",
        ")\n",
        "async def generate_plan(\n",
        "    current_turn: Turn,\n",
        "    project: str,\n",
        "    region: str,\n",
        "    model_name: str,\n",
        "    history: list[Turn] | None = None,\n",
        ") -> PlanOrRespond:\n",
        "    \"\"\"\n",
        "    Generates a plan or a direct response based on the current turn and conversation history.\n",
        "\n",
        "    This function uses a Gemini model to analyze the user's input and the conversation history\n",
        "    to determine whether to generate a step-by-step plan for further action or to provide a\n",
        "    direct response to the user.\n",
        "\n",
        "    Args:\n",
        "        current_turn: The current turn in the conversation, containing the user's input.\n",
        "        project: The Google Cloud project ID.\n",
        "        region: The Google Cloud region.\n",
        "        model_name: The name of the Gemini model to use.\n",
        "        history: A list of previous turns in the conversation (optional).\n",
        "\n",
        "    Returns:\n",
        "        A PlanOrRespond object, which can either contain a Response object (to respond to the user)\n",
        "        or a Plan object (to generate a new plan).\n",
        "    \"\"\"\n",
        "\n",
        "    history = history or []\n",
        "\n",
        "    client = genai.Client(vertexai=True, project=project, location=region)\n",
        "\n",
        "    contents = [\n",
        "        genai_types.Content(role=role, parts=[genai_types.Part.from_text(text=text)])\n",
        "        for turn in history + [current_turn]\n",
        "        for role, text in (\n",
        "            (\"user\", turn.get(\"user_input\")),\n",
        "            (\"model\", turn.get(\"response\") or \"EMPTY\"),\n",
        "        )\n",
        "    ]\n",
        "\n",
        "    content_response = await client.aio.models.generate_content(\n",
        "        model=model_name,\n",
        "        contents=contents,\n",
        "        config=genai_types.GenerateContentConfig(\n",
        "            response_mime_type=\"application/json\",\n",
        "            response_schema=PlanOrRespond,\n",
        "            system_instruction=\"\"\"\n",
        "# Mission\n",
        "For the given user input, come up with a response to the user or a simple step by step plan.\n",
        "\n",
        "## Choices\n",
        "If you can provide a direct response without executing any sub-tasks, provide a response action.\n",
        "If you need clarification or have follow up questions, provide a response action.\n",
        "If the user input requires research to answer or looking up realtime data, provide a plan action.\n",
        "\n",
        "## Instructions for plans\n",
        "The plan should involve individual tasks, that if executed correctly will yield the correct answer. Do not add any superfluous steps.\n",
        "The result of the final step should be the final answer. Make sure that each step has all the information needed - do not skip steps.\n",
        "None of the steps are allowed to be user-facing, they must all be executed by the research agent with no input from the user.\n",
        "A different responder agent will generate a final response to the user after the researcher executes the plan tasks.\n",
        "Only add steps to the plan that still NEED to be done. Do not return previously done steps as part of the plan.\n",
        "\"\"\".strip(),\n",
        "        ),\n",
        "    )\n",
        "\n",
        "    plan_reflection = PlanOrRespond.model_validate_json(content_response.text)\n",
        "\n",
        "    return plan_reflection"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5894eaaa34a4"
      },
      "source": [
        "#### Plan Executor\n",
        "\n",
        "Executes a given plan step-by-step and yields the results of each task."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "metadata": {
        "id": "2cff046fae5a"
      },
      "outputs": [],
      "source": [
        "@retry(\n",
        "    retry=retry_if_exception(is_retryable_error),\n",
        "    wait=wait_exponential(min=1, max=10),\n",
        "    stop=stop_after_attempt(3),\n",
        "    reraise=True,\n",
        ")\n",
        "async def execute_plan(\n",
        "    plan: Plan,\n",
        "    project: str,\n",
        "    region: str,\n",
        "    model_name: str,\n",
        ") -> AsyncGenerator[tuple[int, Task], None]:\n",
        "    \"\"\"\n",
        "    Executes a given plan step-by-step and yields the results of each task.\n",
        "\n",
        "    This function iterates through the tasks in a given plan, executes each task using a Gemini model\n",
        "    with Google Search tool enabled, and yields the index and updated task with the result.\n",
        "\n",
        "    Args:\n",
        "        plan: The plan to execute, containing a list of tasks.\n",
        "        project: The Google Cloud project ID.\n",
        "        region: The Google Cloud region.\n",
        "        model_name: The name of the Gemini model to use.\n",
        "\n",
        "    Yields:\n",
        "        An asynchronous generator that yields tuples of (index, task), where index is the task's\n",
        "        position in the plan and task is the updated task with the execution result.\n",
        "    \"\"\"\n",
        "\n",
        "    executed_plan = plan.model_copy(deep=True)\n",
        "\n",
        "    client = genai.Client(vertexai=True, project=project, location=region)\n",
        "\n",
        "    search_tool = genai_types.Tool(google_search=genai_types.GoogleSearch())\n",
        "    system_instruction = \"Your mission is to execute the research goal provided and respond with findings. The result is not provided directly to the user, but instead provided to another agent to summarize findings.\"\n",
        "\n",
        "    for idx, task in enumerate(executed_plan.tasks):\n",
        "        if task.result is not None:\n",
        "            continue\n",
        "\n",
        "        # last task will be missing result. Will fill in from agent response.\n",
        "        all_tasks = executed_plan.tasks[: idx + 1]\n",
        "        all_tasks_string = \"\\n---\\n\".join(\n",
        "            f\"Goal: {task.goal}\\n\\nResult: {task.result or ''}\" for task in all_tasks\n",
        "        )\n",
        "\n",
        "        contents = f\"# Plan\\nHigh Level Goal: {plan.goal}\\n---\\n{all_tasks_string}\"\n",
        "\n",
        "        content_response = await client.aio.models.generate_content(\n",
        "            model=model_name,\n",
        "            contents=contents,\n",
        "            config=genai_types.GenerateContentConfig(\n",
        "                tools=[search_tool],\n",
        "                system_instruction=system_instruction,\n",
        "            ),\n",
        "        )\n",
        "\n",
        "        task.result = content_response.text\n",
        "\n",
        "        yield idx, task"
      ]
    },
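    {
      "cell_type": "markdown",
      "metadata": {
        "id": "e0b7f3a9c516"
      },
      "source": [
        "`execute_plan` runs tasks strictly one after another, so every Gemini call waits for the previous one. The next cell is a hypothetical sketch of the DAG idea mentioned in the Overview (as in approaches such as LLM Compiler): if each task declared which other tasks it depends on, independent tasks could be grouped into waves and each wave run concurrently. The notebook's `Task` schema has no dependency field, so `dependencies` and `execution_waves` are assumptions for illustration. The sketch uses Python's standard-library `graphlib.TopologicalSorter` and is not LLM Compiler's actual scheduler."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9d4f2a6b8e31"
      },
      "outputs": [],
      "source": [
        "from graphlib import TopologicalSorter\n",
        "\n",
        "# Hypothetical plan: each task maps to the set of tasks it depends on.\n",
        "dependencies = {\n",
        "    \"find_tokyo_population\": set(),\n",
        "    \"find_delhi_population\": set(),\n",
        "    \"compare_populations\": {\"find_tokyo_population\", \"find_delhi_population\"},\n",
        "    \"write_summary\": {\"compare_populations\"},\n",
        "}\n",
        "\n",
        "\n",
        "def execution_waves(deps: dict[str, set[str]]) -> list[list[str]]:\n",
        "    \"\"\"Groups tasks into waves; each task depends only on tasks in earlier waves.\"\"\"\n",
        "    sorter = TopologicalSorter(deps)\n",
        "    sorter.prepare()  # Raises graphlib.CycleError if the dependencies contain a cycle.\n",
        "    waves = []\n",
        "    while sorter.is_active():\n",
        "        ready = sorted(sorter.get_ready())  # Tasks whose dependencies are all done.\n",
        "        waves.append(ready)\n",
        "        # A real executor would run this wave concurrently, then mark each task done.\n",
        "        sorter.done(*ready)\n",
        "    return waves\n",
        "\n",
        "\n",
        "waves = execution_waves(dependencies)\n",
        "for wave_number, wave in enumerate(waves, start=1):\n",
        "    print(f\"Wave {wave_number}: {wave}\")\n",
        "\n",
        "# Four tasks run in three waves instead of four sequential steps.\n",
        "print(len(dependencies), \"tasks,\", len(waves), \"waves\")"
      ]
    },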
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "064fa4511fc7"
      },
      "source": [
        "#### Plan Reflector\n",
        "\n",
        "Reflects on a user's input and an executed plan to determine the next action (response or new plan)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 10,
      "metadata": {
        "id": "050f7917297a"
      },
      "outputs": [],
      "source": [
        "@retry(\n",
        "    retry=retry_if_exception(is_retryable_error),\n",
        "    wait=wait_exponential(min=1, max=10),\n",
        "    stop=stop_after_attempt(3),\n",
        "    reraise=True,\n",
        ")\n",
        "async def reflect_plan(\n",
        "    user_input: str,\n",
        "    executed_plan: Plan,\n",
        "    project: str,\n",
        "    region: str,\n",
        "    model_name: str,\n",
        ") -> PlanOrRespond:\n",
        "    \"\"\"\n",
        "    Reflects on a user's input and an executed plan to determine the next action (response or new plan).\n",
        "\n",
        "    This function uses a Gemini model to analyze the user's last message, the overall goal of the\n",
        "    research agent, and the steps that were executed in the previous plan. Based on this analysis,\n",
        "    it decides whether to generate a direct response to the user or to create a new plan for\n",
        "    further action.\n",
        "\n",
        "    Args:\n",
        "        user_input: The user's most recent input.\n",
        "        executed_plan: The plan that was previously executed.\n",
        "        project: The Google Cloud project ID.\n",
        "        region: The Google Cloud region.\n",
        "        model_name: The name of the Gemini model to use.\n",
        "\n",
        "    Returns:\n",
        "        A PlanOrRespond object, which can either contain a Response object (to respond to the user)\n",
        "        or a Plan object (to generate a new plan).\n",
        "    \"\"\"\n",
        "\n",
        "    client = genai.Client(vertexai=True, project=project, location=region)\n",
        "\n",
        "    system_instructions = \"\"\"\n",
        "# Mission\n",
        "For the given user input, come up with a response to the user or a simple step-by-step plan.\n",
        "\n",
        "## Choices\n",
        "If you can provide a direct response without executing any sub-tasks, provide a response action.\n",
        "If you need clarification or have follow-up questions, provide a response action.\n",
        "If the user input requires multiple steps to answer or a lookup of real-time data, provide a plan action.\n",
        "\n",
        "## Instructions for plans\n",
        "The plan should consist of individual tasks that, if executed correctly, will yield the correct answer. Do not add any superfluous steps.\n",
        "The result of the final step should be the final answer. Make sure that each step has all the information needed - do not skip steps.\n",
        "Only add steps to the plan that still NEED to be done. Do not return previously done steps as part of the plan.\n",
        "\"\"\".strip()\n",
        "\n",
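        "    # Format executed tasks as readable text instead of the raw list repr of the models.\n",
        "    executed_tasks_string = \"\\n---\\n\".join(\n",
        "        f\"Goal: {task.goal}\\n\\nResult: {task.result or ''}\" for task in executed_plan.tasks\n",
        "    )\n",
        "\n",
        "    contents = f\"\"\"\n",
        "The last user message was:\n",
        "{user_input}\n",
        "\n",
        "The main goal of the research agent was:\n",
        "{executed_plan.goal}\n",
        "\n",
        "The research agent executed the following tasks:\n",
        "{executed_tasks_string}\n",
        "\"\"\".strip()\n",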
        "\n",
        "    content_response = await client.aio.models.generate_content(\n",
        "        model=model_name,\n",
        "        contents=contents,\n",
        "        config=genai_types.GenerateContentConfig(\n",
        "            response_mime_type=\"application/json\",\n",
        "            response_schema=PlanOrRespond,\n",
        "            system_instruction=system_instructions,\n",
        "        ),\n",
        "    )\n",
        "\n",
        "    plan_reflection = PlanOrRespond.model_validate_json(content_response.text)\n",
        "\n",
        "    return plan_reflection"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6209aae9eb2a"
      },
      "source": [
        "#### Test all core functions"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 11,
      "metadata": {
        "id": "a3b0a526fe2b"
      },
      "outputs": [],
      "source": [
        "print(\"-\" * 10, \"Reading user input\", \"-\" * 10, end=\"\\n\\n\")\n",
        "\n",
        "example_user_input = \"research best video games for my nerdy 10yo son\"\n",
        "print(example_user_input, end=\"\\n\\n\")\n",
        "\n",
        "print(\"-\" * 10, \"Generating plan\", \"-\" * 10, end=\"\\n\\n\")\n",
        "\n",
        "example_generated_plan_or_respond = await generate_plan(\n",
        "    Turn(user_input=example_user_input),\n",
        "    project=PROJECT,\n",
        "    region=REGION,\n",
        "    model_name=PLANNER_MODEL_NAME,\n",
        ")\n",
        "example_generated_plan = example_generated_plan_or_respond.action\n",
        "assert isinstance(example_generated_plan, Plan), \"Expected action to be plan\"\n",
        "display(ipd.Markdown(stringify_plan(example_generated_plan)))\n",
        "\n",
        "async for idx, new_task in execute_plan(\n",
        "    plan=example_generated_plan,\n",
        "    project=PROJECT,\n",
        "    region=REGION,\n",
        "    model_name=EXECUTOR_MODEL_NAME,\n",
        "):\n",
        "    example_generated_plan.tasks[idx] = new_task\n",
        "\n",
        "    print(\"-\" * 10, \"Executed task\", \"-\" * 10, end=\"\\n\\n\")\n",
        "\n",
        "    display(ipd.Markdown(f\"**Goal**: {new_task.goal}\\n\\n**Result**: {new_task.result}\"))\n",
        "\n",
        "print(\"-\" * 10, \"Reflection on plan\", \"-\" * 10, end=\"\\n\\n\")\n",
        "\n",
        "example_reflection_plan_or_respond = await reflect_plan(\n",
        "    user_input=example_user_input,\n",
        "    executed_plan=example_generated_plan,\n",
        "    project=PROJECT,\n",
        "    region=REGION,\n",
        "    model_name=REFLECTOR_MODEL_NAME,\n",
        ")\n",
        "example_reflection_response = example_reflection_plan_or_respond.action\n",
        "assert isinstance(\n",
        "    example_reflection_response, Response\n",
        "), \"Expected action to be response\"\n",
        "display(ipd.Markdown(example_reflection_response.response))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "a60f2d6417d4"
      },
      "source": [
        "### Nodes"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "e4967db1d810"
      },
      "source": [
        "#### Planner Node\n",
        "\n",
        "Generates a plan or a direct response based on the current turn and conversation history."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 12,
      "metadata": {
        "id": "01bc6440f912"
      },
      "outputs": [],
      "source": [
        "async def planner_node(\n",
        "    state: GraphSession,\n",
        "    config: lc_config.RunnableConfig,\n",
        ") -> lg_types.Command[Literal[ExecutorNodeTargetLiteral, PostProcessNodeTargetLiteral]]:\n",
        "    \"\"\"\n",
        "    Asynchronously generates a plan or a direct response based on the current conversation state.\n",
        "\n",
        "    This function takes the current conversation state, which includes the user's input and history,\n",
        "    and uses the `generate_plan` function to determine whether to create a plan for further action\n",
        "    or to provide a direct response. It then updates the conversation state and directs the flow\n",
        "    to the appropriate next node (executor or post-processing).\n",
        "\n",
        "    Args:\n",
        "        state: The current state of the conversation session, including user input and history.\n",
        "        config: The LangChain RunnableConfig containing agent-specific configurations.\n",
        "\n",
        "    Returns:\n",
        "        A Command object that specifies the next node to transition to (executor or post-processing)\n",
        "        and the updated conversation state. The state includes the generated plan or response.\n",
        "\n",
        "    Raises:\n",
        "        TypeError: If the plan reflection action is of an unsupported type.\n",
        "    \"\"\"\n",
        "\n",
        "    agent_config = AgentConfig.model_validate(\n",
        "        config[\"configurable\"].get(\"agent_config\", {})\n",
        "    )\n",
        "\n",
        "    stream_writer = get_stream_writer()\n",
        "\n",
        "    current_turn = state.get(\"current_turn\")\n",
        "    assert current_turn is not None, \"current turn must be set\"\n",
        "\n",
        "    user_input = current_turn.get(\"user_input\")\n",
        "    assert user_input is not None, \"user input must be set\"\n",
        "\n",
        "    turns = state.get(\"turns\", [])\n",
        "\n",
        "    plan_reflection = await generate_plan(\n",
        "        current_turn=current_turn,\n",
        "        project=agent_config.project,\n",
        "        region=agent_config.region,\n",
        "        model_name=agent_config.planner_model_name,\n",
        "        history=turns,\n",
        "    )\n",
        "\n",
        "    next_node = None\n",
        "    if isinstance(plan_reflection.action, Plan):\n",
        "        next_node = EXECUTOR_NODE_NAME\n",
        "\n",
        "        # Ensure results aren't set\n",
        "        for task in plan_reflection.action.tasks:\n",
        "            task.result = None\n",
        "\n",
        "        # Set initial plan\n",
        "        current_turn[\"plan\"] = plan_reflection.action\n",
        "        stream_writer({\"plan\": plan_reflection.action.model_dump(mode=\"json\")})\n",
        "\n",
        "    elif isinstance(plan_reflection.action, Response):\n",
        "        next_node = POST_PROCESS_NODE_NAME\n",
        "\n",
        "        # Update turn response\n",
        "        current_turn[\"response\"] = plan_reflection.action.response\n",
        "        stream_writer({\"response\": plan_reflection.action.response})\n",
        "\n",
        "    else:\n",
        "        raise TypeError(\n",
        "            f\"Unsupported plan reflection action: {type(plan_reflection.action)}\"\n",
        "        )\n",
        "\n",
        "    return lg_types.Command(\n",
        "        update=GraphSession(current_turn=current_turn),\n",
        "        goto=next_node,\n",
        "    )"
      ]
    },
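    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Stripped of LangGraph, the planner node's routing is a type dispatch on the structured model output: a plan goes to the executor with any results cleared, and a response goes straight to post-processing. The sketch below shows only that dispatch. `SketchTask`, `SketchPlan`, `SketchResponse`, and the node names `executor` and `post_process` are hypothetical placeholders, not the notebook's models or `*_NODE_NAME` constants.\n",
        "\n",
        "```python\n",
        "from dataclasses import dataclass, field\n",
        "from typing import Optional, Union\n",
        "\n",
        "\n",
        "@dataclass\n",
        "class SketchTask:\n",
        "    goal: str\n",
        "    result: Optional[str] = None\n",
        "\n",
        "\n",
        "@dataclass\n",
        "class SketchPlan:\n",
        "    goal: str\n",
        "    tasks: list[SketchTask] = field(default_factory=list)\n",
        "\n",
        "\n",
        "@dataclass\n",
        "class SketchResponse:\n",
        "    response: str\n",
        "\n",
        "\n",
        "def route(action: Union[SketchPlan, SketchResponse]) -> str:\n",
        "    \"\"\"Returns the placeholder name of the node that handles `action`.\"\"\"\n",
        "    if isinstance(action, SketchPlan):\n",
        "        for task in action.tasks:  # a fresh plan must start without results\n",
        "            task.result = None\n",
        "        return \"executor\"\n",
        "    if isinstance(action, SketchResponse):\n",
        "        return \"post_process\"\n",
        "    raise TypeError(f\"Unsupported plan reflection action: {type(action)}\")\n",
        "\n",
        "\n",
        "print(route(SketchResponse(\"Hi! How can I help?\")))  # post_process\n",
        "print(route(SketchPlan(\"Find games\", [SketchTask(\"Search\", \"stale\")])))  # executor\n",
        "```"
      ]
    },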
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3b658bf8e3f4"
      },
      "source": [
        "#### Executor Node\n",
        "\n",
        "Executes the pending tasks of the current plan one at a time, streams each executed task, and then hands off to the reflector."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 13,
      "metadata": {
        "id": "69944a06a9a0"
      },
      "outputs": [],
      "source": [
        "async def executor_node(\n",
        "    state: GraphSession,\n",
        "    config: lc_config.RunnableConfig,\n",
        ") -> lg_types.Command[Literal[ReflectorNodeTargetLiteral]]:\n",
        "    \"\"\"\n",
        "    Asynchronously executes a plan's tasks and updates the conversation state.\n",
        "\n",
        "    This function takes the current conversation state, which includes a plan, and executes each task in that plan that does not yet have a result.\n",
        "    It utilizes the `execute_plan` function to process each task, updating the plan with the results as it goes.\n",
        "    The function also streams the executed tasks to the user via the stream writer.\n",
        "\n",
        "    Args:\n",
        "        state: The current state of the conversation session, including the plan to execute.\n",
        "        config: The LangChain RunnableConfig containing agent-specific configurations.\n",
        "\n",
        "    Returns:\n",
        "        A Command object that specifies the next node to transition to (reflector) and the\n",
        "        updated conversation state. The state includes the plan with executed tasks.\n",
        "\n",
        "    Raises:\n",
        "        AssertionError: If the plan is not generated before execution.\n",
        "    \"\"\"\n",
        "\n",
        "    agent_config = AgentConfig.model_validate(\n",
        "        config[\"configurable\"].get(\"agent_config\", {})\n",
        "    )\n",
        "\n",
        "    stream_writer = get_stream_writer()\n",
        "\n",
        "    current_turn = state.get(\"current_turn\")\n",
        "    assert current_turn is not None, \"current turn must be set\"\n",
        "\n",
        "    plan = current_turn.get(\"plan\")\n",
        "    assert plan is not None, \"plan must be set\"\n",
        "\n",
        "    async for idx, executed_task in execute_plan(\n",
        "        plan=plan,\n",
        "        project=agent_config.project,\n",
        "        region=agent_config.region,\n",
        "        model_name=agent_config.executor_model_name,\n",
        "    ):\n",
        "        # update state with executed task\n",
        "        plan.tasks[idx] = executed_task\n",
        "\n",
        "        stream_writer({\"executed_task\": executed_task.model_dump(mode=\"json\")})\n",
        "\n",
        "    current_turn[\"plan\"] = plan\n",
        "\n",
        "    return lg_types.Command(\n",
        "        update=GraphSession(current_turn=current_turn),\n",
        "        goto=REFLECTOR_NODE_NAME,\n",
        "    )"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "78cddeb5cbd2"
      },
      "source": [
        "#### Reflector Node\n",
        "\n",
        "Reflects on a user's input and an executed plan to determine the next action (response or new plan)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 14,
      "metadata": {
        "id": "ed761d688327"
      },
      "outputs": [],
      "source": [
        "async def reflector_node(\n",
        "    state: GraphSession,\n",
        "    config: lc_config.RunnableConfig,\n",
        ") -> lg_types.Command[Literal[ExecutorNodeTargetLiteral, PostProcessNodeTargetLiteral]]:\n",
        "    \"\"\"\n",
        "    Asynchronously reflects on the executed plan and determines the next action.\n",
        "\n",
        "    This function takes the current conversation state, which includes the executed plan, and uses\n",
        "    the `reflect_plan` function to analyze the results and decide whether to generate a new plan\n",
        "    or provide a direct response. It then updates the conversation state and directs the flow\n",
        "    to the appropriate next node (executor or post-processing).\n",
        "\n",
        "    Args:\n",
        "        state: The current state of the conversation session, including the executed plan.\n",
        "        config: The LangChain RunnableConfig containing agent-specific configurations.\n",
        "\n",
        "    Returns:\n",
        "        A Command object that specifies the next node to transition to (executor or post-processing) and the\n",
        "        updated conversation state. The state includes the updated plan or response.\n",
        "\n",
        "    Raises:\n",
        "        AssertionError: If the plan is not generated or not fully executed before reflection.\n",
        "        TypeError: If the plan reflection action is of an unsupported type.\n",
        "    \"\"\"\n",
        "\n",
        "    agent_config = AgentConfig.model_validate(\n",
        "        config[\"configurable\"].get(\"agent_config\", {})\n",
        "    )\n",
        "\n",
        "    stream_writer = get_stream_writer()\n",
        "\n",
        "    current_turn = state.get(\"current_turn\")\n",
        "    assert current_turn is not None, \"current turn must be set\"\n",
        "\n",
        "    user_input = current_turn.get(\"user_input\")\n",
        "    assert user_input is not None, \"user input must be set\"\n",
        "\n",
        "    plan = current_turn.get(\"plan\")\n",
        "    assert plan is not None, \"plan must be set\"\n",
        "\n",
        "    assert all(\n",
        "        task.result is not None for task in plan.tasks\n",
        "    ), \"Must execute each plan task before reflection.\"\n",
        "\n",
        "    plan_reflection = await reflect_plan(\n",
        "        user_input=user_input,\n",
        "        executed_plan=plan,\n",
        "        project=agent_config.project,\n",
        "        region=agent_config.region,\n",
        "        model_name=agent_config.reflector_model_name,\n",
        "    )\n",
        "\n",
        "    next_node = None\n",
        "    if isinstance(plan_reflection.action, Plan):\n",
        "        next_node = EXECUTOR_NODE_NAME\n",
        "\n",
        "        # Ensure results aren't set\n",
        "        for task in plan_reflection.action.tasks:\n",
        "            task.result = None\n",
        "\n",
        "        # Add new tasks from plan reflection\n",
        "        current_turn[\"plan\"].tasks += plan_reflection.action.tasks\n",
        "\n",
        "        stream_writer({\"plan\": current_turn[\"plan\"].model_dump(mode=\"json\")})\n",
        "\n",
        "    elif isinstance(plan_reflection.action, Response):\n",
        "        next_node = POST_PROCESS_NODE_NAME\n",
        "\n",
        "        # Update turn response\n",
        "        current_turn[\"response\"] = plan_reflection.action.response\n",
        "\n",
        "        stream_writer({\"response\": current_turn[\"response\"]})\n",
        "    else:  # never\n",
        "        raise TypeError(\n",
        "            f\"Unsupported plan reflection action: {type(plan_reflection.action)}\"\n",
        "        )\n",
        "\n",
        "    return lg_types.Command(\n",
        "        update=GraphSession(current_turn=current_turn),\n",
        "        goto=next_node,\n",
        "    )"
      ]
    },
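    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Unlike the planner, the reflector extends the existing plan instead of replacing it: follow-up tasks are appended with empty results, so the next executor pass (which skips tasks that already have a result) runs only the new ones, while later reflections still see every earlier result. A minimal sketch of that append step, using hypothetical `SketchTask`/`SketchPlan` dataclasses in place of the notebook's pydantic models:\n",
        "\n",
        "```python\n",
        "from dataclasses import dataclass, field\n",
        "from typing import Optional\n",
        "\n",
        "\n",
        "@dataclass\n",
        "class SketchTask:\n",
        "    goal: str\n",
        "    result: Optional[str] = None\n",
        "\n",
        "\n",
        "@dataclass\n",
        "class SketchPlan:\n",
        "    goal: str\n",
        "    tasks: list[SketchTask] = field(default_factory=list)\n",
        "\n",
        "\n",
        "def extend_plan(executed: SketchPlan, follow_up: SketchPlan) -> list[int]:\n",
        "    \"\"\"Appends follow-up tasks with results cleared; returns the pending indices.\"\"\"\n",
        "    for task in follow_up.tasks:\n",
        "        task.result = None\n",
        "    executed.tasks += follow_up.tasks  # original goal and finished tasks are kept\n",
        "    return [i for i, t in enumerate(executed.tasks) if t.result is None]\n",
        "\n",
        "\n",
        "plan = SketchPlan(\"Recommend games\", [SketchTask(\"Search PS5 titles\", \"A, B, C\")])\n",
        "follow_up = SketchPlan(\"Narrow it down\", [SketchTask(\"Check age ratings\")])\n",
        "print(extend_plan(plan, follow_up))  # [1]\n",
        "```"
      ]
    },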
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "00879e69b95f"
      },
      "source": [
        "#### Post-Process Node\n",
        "\n",
        "Adds the current turn to the conversation history and resets the current turn."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 15,
      "metadata": {
        "id": "4be2c75208a3"
      },
      "outputs": [],
      "source": [
        "def post_process_node(\n",
        "    state: GraphSession,\n",
        "    config: lc_config.RunnableConfig,\n",
        ") -> lg_types.Command[EndNodeTargetLiteral]:\n",
        "    \"\"\"\n",
        "    Finalizes the current conversation turn.\n",
        "\n",
        "    This function takes the current conversation state, validates that the current turn and its response are set,\n",
        "    adds the completed turn to the conversation history, and resets the current turn. This effectively concludes\n",
        "    the processing of the current user input and prepares the session for the next input.\n",
        "\n",
        "    Args:\n",
        "        state: The current state of the conversation session.\n",
        "        config: The LangChain RunnableConfig (unused in this function).\n",
        "\n",
        "    Returns:\n",
        "        A Command object specifying the end of the graph execution and the updated conversation state.\n",
        "    \"\"\"\n",
        "\n",
        "    del config  # unused\n",
        "\n",
        "    current_turn = state.get(\"current_turn\")\n",
        "\n",
        "    assert current_turn is not None, \"Current turn must be set.\"\n",
        "    assert (\n",
        "        current_turn.get(\"response\") is not None\n",
        "    ), \"Response from current turn must be set.\"\n",
        "\n",
        "    turns = state.get(\"turns\", []) + [current_turn]\n",
        "\n",
        "    return lg_types.Command(update=GraphSession(current_turn=None, turns=turns))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "740d6e04b1c0"
      },
      "source": [
        "## Compile Task Planner Agent"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 16,
      "metadata": {
        "id": "fa296f770e91"
      },
      "outputs": [],
      "source": [
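        "def load_graph():\n",
        "    \"\"\"Builds the task planner state graph.\n",
        "\n",
        "    No explicit edges are added: every node returns a `Command` whose `goto` picks the\n",
        "    next node, and the `Command[Literal[...]]` return annotations tell LangGraph which\n",
        "    destinations are possible (used, for example, when drawing the graph).\n",
        "    \"\"\"\n",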
        "    state_graph = graph.StateGraph(GraphSession)\n",
        "\n",
        "    state_graph.add_node(PLANNER_NODE_NAME, planner_node)\n",
        "    state_graph.add_node(EXECUTOR_NODE_NAME, executor_node)\n",
        "    state_graph.add_node(REFLECTOR_NODE_NAME, reflector_node)\n",
        "    state_graph.add_node(POST_PROCESS_NODE_NAME, post_process_node)\n",
        "\n",
        "    state_graph.set_entry_point(PLANNER_NODE_NAME)\n",
        "\n",
        "    return state_graph\n",
        "\n",
        "\n",
        "state_graph = load_graph()\n",
        "\n",
        "checkpointer = memory.MemorySaver()\n",
        "compiled_graph = state_graph.compile(checkpointer=checkpointer)"
      ]
    },
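    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Because the graph has no explicit edges, each turn's path is decided at runtime by the destination each node returns. The plain-Python sketch below imitates that control flow with hypothetical node functions that return the next node's name, or `None` to stop. The node bodies are toy stand-ins for the real planner, executor, and reflector, and the `max_steps` cap of 25 is an arbitrary choice for the sketch, loosely analogous to LangGraph's `recursion_limit` setting.\n",
        "\n",
        "```python\n",
        "from typing import Callable, Optional\n",
        "\n",
        "# Hypothetical node signature: takes the shared state dict and returns the name of\n",
        "# the next node, or None to stop (loosely mirroring `Command(goto=...)`).\n",
        "Node = Callable[[dict], Optional[str]]\n",
        "\n",
        "\n",
        "def run_graph(\n",
        "    nodes: dict[str, Node], entry: str, state: dict, max_steps: int = 25\n",
        ") -> list[str]:\n",
        "    \"\"\"Follows node-chosen destinations from `entry`; returns the visited path.\"\"\"\n",
        "    path: list[str] = []\n",
        "    current: Optional[str] = entry\n",
        "    while current is not None:\n",
        "        if len(path) >= max_steps:  # guard against an endless plan/reflect loop\n",
        "            raise RuntimeError(\"step limit reached\")\n",
        "        path.append(current)\n",
        "        current = nodes[current](state)\n",
        "    return path\n",
        "\n",
        "\n",
        "def planner(state: dict) -> Optional[str]:\n",
        "    state[\"tasks\"] = [\"search\"]\n",
        "    return \"executor\"\n",
        "\n",
        "\n",
        "def executor(state: dict) -> Optional[str]:\n",
        "    state[\"executed\"] = len(state[\"tasks\"])\n",
        "    return \"reflector\"\n",
        "\n",
        "\n",
        "def reflector(state: dict) -> Optional[str]:\n",
        "    # Request one follow-up task on the first pass, then respond.\n",
        "    if len(state[\"tasks\"]) < 2:\n",
        "        state[\"tasks\"].append(\"compare\")\n",
        "        return \"executor\"\n",
        "    return \"post_process\"\n",
        "\n",
        "\n",
        "def post_process(state: dict) -> Optional[str]:\n",
        "    return None  # end of the turn\n",
        "\n",
        "\n",
        "nodes = {\n",
        "    \"planner\": planner,\n",
        "    \"executor\": executor,\n",
        "    \"reflector\": reflector,\n",
        "    \"post_process\": post_process,\n",
        "}\n",
        "print(run_graph(nodes, \"planner\", {}))\n",
        "# ['planner', 'executor', 'reflector', 'executor', 'reflector', 'post_process']\n",
        "```"
      ]
    },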
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "683793bb886f"
      },
      "source": [
        "### Visualize Agent Graph"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 17,
      "metadata": {
        "id": "73ff616d10c5"
      },
      "outputs": [],
      "source": [
        "png_bytes = compiled_graph.get_graph().draw_mermaid_png()\n",
        "\n",
        "display(ipd.Image(png_bytes))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "381f3368c641"
      },
      "source": [
        "### Wrapper function to stream agent output to the notebook"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 29,
      "metadata": {
        "id": "edcf2f3f08aa"
      },
      "outputs": [],
      "source": [
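        "async def ask(user_input: str, session: str | None = None):\n",
        "    \"\"\"Runs one agent turn and renders the streamed plan, task results, and response.\"\"\"\n",
        "\n",
        "    # Reusing a session ID as the checkpointer thread ID keeps history across calls.\n",
        "    thread_id = session or uuid.uuid4().hex\n",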
        "\n",
        "    agent_config = AgentConfig(\n",
        "        project=PROJECT,\n",
        "        region=REGION,\n",
        "        planner_model_name=PLANNER_MODEL_NAME,\n",
        "        executor_model_name=EXECUTOR_MODEL_NAME,\n",
        "        reflector_model_name=REFLECTOR_MODEL_NAME,\n",
        "    )\n",
        "\n",
        "    current_source = last_source = None\n",
        "    task_idx = 0\n",
        "    all_text = \"\"\n",
        "    async for stream_mode, chunk in compiled_graph.astream(\n",
        "        input={\"current_turn\": {\"user_input\": user_input}},\n",
        "        config={\"configurable\": {\"thread_id\": thread_id, \"agent_config\": agent_config}},\n",
        "        stream_mode=[\"custom\"],\n",
        "    ):\n",
        "        assert isinstance(chunk, dict), \"Expected dictionary chunk\"\n",
        "\n",
        "        text = \"\"\n",
        "\n",
        "        if \"response\" in chunk:\n",
        "            # if no prior text, then no plan was generated\n",
        "            if all_text.strip() == \"\":\n",
        "                text = chunk[\"response\"]\n",
        "            else:\n",
        "                text = \"### Reflection\\n\\n\" + chunk[\"response\"]\n",
        "\n",
        "            current_source = \"response\"\n",
        "\n",
        "        elif \"plan\" in chunk:\n",
        "            plan = Plan.model_validate(chunk[\"plan\"])\n",
        "            plan_string = stringify_plan(plan=plan, include_results=False)\n",
        "            text = f\"### Generated execution plan...\\n\\n{plan_string}\"\n",
        "            current_source = \"plan\"\n",
        "\n",
        "        elif \"executed_task\" in chunk:\n",
        "            task_idx += 1\n",
        "            task = Task.model_validate(chunk[\"executed_task\"])\n",
        "            task_string = stringify_task(task=task, include_results=True)\n",
        "            text = f\"### Executed task #{task_idx}...\\n\\n{task_string}\"\n",
        "            current_source = f\"executed_task_{task_idx}\"\n",
        "\n",
        "        else:\n",
        "            print(\"unhandled chunk case:\", chunk)\n",
        "\n",
        "        if last_source is not None and last_source != current_source:\n",
        "            text = \"\\n\\n---\\n\\n\" + text\n",
        "\n",
        "        last_source = current_source\n",
        "\n",
        "        all_text += text\n",
        "\n",
        "        display(ipd.Markdown(all_text), clear=True)"
      ]
    },
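    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The rendering rules in `ask` can be exercised without running the graph. The sketch below is a simplified, hypothetical `render` helper that applies the same rules to a list of chunks: a `Reflection` heading only when earlier text exists, a running task counter, and a `---` separator whenever the chunk source changes. It assumes the `plan` and `executed_task` payloads are already strings; in the notebook they are JSON dicts formatted with `stringify_plan` and `stringify_task`. Unlike `ask`, it skips unknown chunks silently instead of printing them.\n",
        "\n",
        "```python\n",
        "def render(chunks: list[dict]) -> str:\n",
        "    \"\"\"Accumulates custom stream chunks into one markdown string, like `ask`.\"\"\"\n",
        "    all_text, last_source, task_idx = \"\", None, 0\n",
        "    for chunk in chunks:\n",
        "        if \"response\" in chunk:\n",
        "            # No prior text means no plan was generated, so skip the heading.\n",
        "            prefix = \"### Reflection\\n\\n\" if all_text.strip() else \"\"\n",
        "            text, source = prefix + chunk[\"response\"], \"response\"\n",
        "        elif \"plan\" in chunk:\n",
        "            text = \"### Generated execution plan...\\n\\n\" + chunk[\"plan\"]\n",
        "            source = \"plan\"\n",
        "        elif \"executed_task\" in chunk:\n",
        "            task_idx += 1\n",
        "            text = f\"### Executed task #{task_idx}...\\n\\n\" + chunk[\"executed_task\"]\n",
        "            source = f\"executed_task_{task_idx}\"\n",
        "        else:\n",
        "            continue  # this sketch silently skips unknown chunk types\n",
        "        # Insert a horizontal rule whenever the stream switches to a new source.\n",
        "        if last_source is not None and last_source != source:\n",
        "            text = \"\\n\\n---\\n\\n\" + text\n",
        "        last_source = source\n",
        "        all_text += text\n",
        "    return all_text\n",
        "\n",
        "\n",
        "print(render([{\"response\": \"Hello!\"}]))  # Hello!\n",
        "```"
      ]
    },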
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3bea54263302"
      },
      "source": [
        "## Test Conversation"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 30,
      "metadata": {
        "id": "24574cb0c9ef"
      },
      "outputs": [],
      "source": [
        "session = uuid.uuid4().hex"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 31,
      "metadata": {
        "id": "fec8d9b95e5a"
      },
      "outputs": [],
      "source": [
        "await ask(\"hi\", session=session)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 32,
      "metadata": {
        "id": "94d40bf397e3"
      },
      "outputs": [],
      "source": [
        "await ask(\"can you recommend some video games for my child?\", session=session)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 33,
      "metadata": {
        "id": "566ad868d3cd"
      },
      "outputs": [],
      "source": [
        "await ask(\n",
        "    \"he is 15. he's into medieval history so he might like something related. He plays on his Playstation 5\",\n",
        "    session=session,\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 34,
      "metadata": {
        "id": "874c5a60a627"
      },
      "outputs": [],
      "source": [
        "await ask(\n",
        "    \"I don't know anything about gaming. Can you give me more info?\",\n",
        "    session=session,\n",
        ")"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "name": "task-planner.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
